from typing import Tuple, Union, Dict
import numpy as np

import torch
import sys, os
from tianshou.utils.net.common import get_dict_state_decorator

import gymnasium as gym

from Causal import Dynamics, GroundTruthGraph, DynamicsGrad, DynamicsAC

from Initializers.init_utils import AttrDict
from State.extractor import Extractor
from Policy.Reward.goal_reward import GoalReward
from Policy.Reward.distance_reward import GoalDistanceReward
from Policy.goal_policy import GoalPolicy
from Policy.Reward.rew_term_trunc_manager import RewardTermTruncManager
from Policy.hindsight_filter import ActionGraphHindsightFilter, NonPassiveGraphHindsightFilter, AllGraphHindsightFilter, ControlHindsightFilter
from Initializers.flat_obs_normalize import FlatNormalization

from Networks import GraphEncoding, DiaynDiscriminator
from Initializers.policy import initialize_policy


def initialize_dynamics(config: AttrDict, env: gym.Env, extractor: Extractor, norm: FlatNormalization, wdb_run=None) -> Dynamics:
    """
    Builds the dynamics (graph) model selected by config.dynamics.type.

    "gt": constructs a GroundTruthGraph from env and extractor and sets
        config.train.dynamics_warmup_step to 0.
    "ac": constructs a DynamicsAC and raises config.train.init_random_step to
        at least config.train.dynamics_pretrain_step. The configured
        dynamics_warmup_step is left untouched.

    Note that config is mutated in place in both cases.

    Args:
        config: experiment configuration; reads config.dynamics.type and config.train.
        env: environment handed to the dynamics constructor.
        extractor: state extractor handed to the dynamics constructor.
        norm: flat observation normalization, only passed to DynamicsAC.
        wdb_run: optional logging run handle, forwarded to DynamicsAC.

    Returns:
        The constructed dynamics model.

    Raises:
        NotImplementedError: if config.dynamics.type is neither "gt" nor "ac".
    """
    dynamics_type = config.dynamics.type
    if dynamics_type == "gt":
        config.train.dynamics_warmup_step = 0
        return GroundTruthGraph(env, extractor)
    if dynamics_type == "ac":
        config.train.init_random_step = max(config.train.init_random_step, config.train.dynamics_pretrain_step)
        # forward the caller's run handle instead of silently discarding it
        return DynamicsAC(env, extractor, norm, config, wdb_run=wdb_run)
    raise NotImplementedError(f"unknown dynamics type: {dynamics_type}")

def initialize_graph_encoding(config: AttrDict, device: torch.device):
    """
    Returns a GraphEncoding built from config and moved to device when
    config.policy.graph_action_space is "graph_encoding"; returns None otherwise.
    """
    wants_encoding = config.policy.graph_action_space == "graph_encoding"
    return GraphEncoding(config).to(device) if wants_encoding else None


def initialize_gc_policy(
        config: AttrDict,
        env: gym.Env,
        extractor: Extractor,
        dynamics: Dynamics,
        net_args: Union[Dict, None] = None,
) -> Tuple[GoalPolicy, RewardTermTruncManager]:
    """
    Builds the goal-conditioned policy together with its reward/termination manager.

    The policy input is assembled from the "observation" and "desired_goal"
    keys (plus "reached_graph_counter" when config.policy.use_reached_graph_counter
    is set). The underlying tianshou policy is created with initialize_policy,
    using the discrete or continuous algorithm from config.policy depending on
    the type of env.action_space, and is wrapped in a GoalPolicy together with
    the hindsight filter selected by config.data.her.filter.form.

    Side effect: sets config.num_policies to 1.

    Args:
        config: experiment configuration.
        env: environment; must expose a 1D Box observation_space, a goal_space
            and an action_space.
        extractor: state extractor passed on to the GoalPolicy.
        dynamics: dynamics model passed on to the GoalPolicy.
        net_args: optional network arguments forwarded to initialize_policy.

    Returns:
        A (policy, rewtermdone) tuple.

    Raises:
        NotImplementedError: for an unknown reward type, action space type or
            hindsight filter form.
    """
    assert isinstance(env.observation_space, gym.spaces.Box) and len(env.observation_space.shape) == 1

    state_shape = {"observation": env.observation_space.shape[0], "desired_goal": env.goal_space.shape[0]}
    keys = ["observation", "desired_goal"]
    if config.policy.use_reached_graph_counter:
        state_shape["reached_graph_counter"] = 1
        keys.append("reached_graph_counter")
    # decides what keys go into the policy from the observation (not the achieved goal)
    dict_state_dec, flat_state_shape = get_dict_state_decorator(state_shape, keys)

    device = config.device
    action_space = env.action_space

    # TODO: we could add in other rewards in the future (distributional, etc)
    reward_type = config.policy.reward.reward_type
    if reward_type == "goal":
        rtt = GoalReward(config=config)
    elif reward_type == "distgoal":
        rtt = GoalDistanceReward(config=config)
    else:
        # previously fell through and failed later with an unbound local
        raise NotImplementedError(f"unknown reward type: {reward_type}")

    rewtermdone = RewardTermTruncManager([rtt])

    if isinstance(action_space, gym.spaces.Discrete):
        algo = config.policy.discrete_algo
    elif isinstance(action_space, gym.spaces.Box):
        algo = config.policy.continuous_algo
    else:
        raise NotImplementedError(f"unknown action space type: {type(action_space)}")

    ts_policy = initialize_policy(algo, flat_state_shape, action_space, config.policy, dict_state_dec, device, net_args=net_args)

    # passive graph: identity over the factors with one extra leading column
    # (presumably the action column -- TODO confirm), with entry [0, 0] switched on
    num_factors = config.num_factors
    passive_graph = np.eye(num_factors)
    passive_graph = np.concatenate([np.array([[0] * num_factors]).T, passive_graph], axis=-1)
    passive_graph[0, 0] = 1

    filter_config = config.data.her.filter
    target_idx = filter_config.target_idx
    min_non_passive = filter_config.min_non_passive
    target_idx_diff = -2 if target_idx < 0 else 1
    filter_form = filter_config.form
    if filter_form == "non_passive":
        hindsight_filter = NonPassiveGraphHindsightFilter(passive_graph, min_non_passive, -1, -1)  # TODO: target indices not defined
    elif filter_form == "target_non_passive":
        # convert a negative index to its non-negative equivalent; a
        # non-negative index is used as given (it was left unbound before)
        use_target_graph_idx = num_factors + target_idx if target_idx < 0 else target_idx
        hindsight_filter = NonPassiveGraphHindsightFilter(passive_graph, min_non_passive, use_target_graph_idx + 1, use_target_graph_idx)  # TODO: target indices not defined
    elif filter_form == "control":
        # assumes the controllable object is always at position 1
        hindsight_filter = ControlHindsightFilter(passive_graph, min_non_passive, target_idx + target_idx_diff, target_idx, 1)
    elif filter_form == "action_graph":
        hindsight_filter = ActionGraphHindsightFilter(0, target_idx + target_idx_diff, target_idx, passive_graph, min_non_passive)
    elif filter_form == "all":
        hindsight_filter = AllGraphHindsightFilter(passive_graph, min_non_passive, target_idx + target_idx_diff, target_idx)
    else:
        raise NotImplementedError("filter form unknown")

    policy = GoalPolicy(ts_policy, dynamics, rewtermdone, extractor, action_space, config, hindsight_filter=hindsight_filter)
    config.num_policies = 1
    return policy, rewtermdone

def initialize_gen_net_args(config, extractor):
    """
    Loads and completes the network arguments for the policy networks.

    Returns None when config.policy.net_config_path is empty. Otherwise reads
    the config file at that path, selects the entry named by
    config.policy.net_config_name, and fills in the factored-state fields from
    the extractor (extractor.longest and extractor.num_factors), the gpu id
    from config.cuda_id, and the extractor itself.

    Side effect: appends <sys.path[0]>/Causal/ac_infer to sys.path if missing.
    """
    if not len(config.policy.net_config_path):
        return None

    # slightly redundant if the ac path is also loaded in ac_dynamics
    ac_infer_dir = os.path.join(sys.path[0], "Causal", "ac_infer")
    if ac_infer_dir not in sys.path:
        sys.path.append(ac_infer_dir)
    from Causal.ac_infer.Hyperparam.read_config import read_config
    from State.utils import ObjDict

    net_args = read_config(config.policy.net_config_path)[config.policy.net_config_name]

    # assign the appropriate factored components
    net_args.factor = ObjDict()
    factor_args = net_args.factor
    factor_args.single_obj_dim = extractor.longest
    factor_args.first_obj_dim = extractor.num_factors * extractor.longest
    factor_args.object_dim = extractor.longest
    factor_args.name_idx = -1
    factor_args.query_aggregate = True
    net_args.factor_net.append_keys = False  # there is no key separation
    factor_args.num_keys = extractor.num_factors
    factor_args.num_queries = extractor.num_factors
    factor_args.start_dim = 0

    net_args.aggregate_final = True  # policy space is never factor dependent
    net_args.gpu = config.cuda_id
    net_args.extractor = extractor  # needs the extractor to pad the state
    return net_args


def initialize_models(config: AttrDict, env: gym.Env, norm: FlatNormalization, wdb_run = None):
    """
    Constructs the models used for training from the config and environment:
    the state extractor, the policy network arguments, the dynamics model, the
    optional graph encoding, and the goal-conditioned policy with its
    reward/terminate/truncate manager.

    Returns (dynamics, graph_encoding, rewtermdone, policy). The dynamics and
    policy are moved to config.device; graph_encoding may be None.
    """
    target_device = config.device

    # the extractor handles converting a 1D observation b x N into b x n x k,
    # where n is the number of factors
    state_extractor = Extractor(env)

    # Arguments for the networks that go inside of the policy (the networks that
    # need them live in Causal/ac_infer/Network/net_types.py). Only arguments are
    # built here, using the extractor's num_factors and longest.
    policy_net_args = initialize_gen_net_args(config, state_extractor)

    # The actual cause inference model (graph compute code). It handles its own
    # logging, which is why it receives wdb_run. The AC parameters are defined
    # through the yaml file described by config.dynamics_config_path.
    dynamics = initialize_dynamics(config, env, state_extractor, norm, wdb_run)
    dynamics = dynamics.to(target_device)

    # Converts the graph of AC binaries into an embedding space (policy space)
    # if used; it is not obvious how important this is.
    graph_encoding = initialize_graph_encoding(config, target_device)

    # Goal-conditioned policy:
    #  - includes the reward / terminate / truncate function
    #  - policy input space: observation, desired_goal, reached_graph_counter
    #  - a tianshou policy (see Initializers/policy.initialize_policy) is built
    #    with the net_args, or a default MLP, and placed inside a
    #    Policy.goal_policy GoalPolicy
    #  - GoalPolicy is the interface through which the collector calls
    #    get_action and the trainer calls update()
    policy, rewtermdone = initialize_gc_policy(config, env, state_extractor, dynamics, net_args=policy_net_args)
    policy.to(target_device)

    return dynamics, graph_encoding, rewtermdone, policy
